{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision import *\n",
    "torch.cuda.set_device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.randn(2, 32, 224, 112).cuda()\n",
    "b = torch.randn(2, 64, 112, 56).cuda()\n",
    "out = torch.zeros(32, 64).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "140 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Naive reference: one elementwise product + full reduction per (i, j) channel pair.\n",
    "a = torch.randn(112, 56, 32).cuda()\n",
    "b = torch.randn(112, 56, 64).cuda()\n",
    "out = torch.zeros(32, 64).cuda()\n",
    "h, w = a.shape[:2]\n",
    "for i in range(out.shape[0]):\n",
    "    for j in range(out.shape[1]):\n",
    "        out[i, j] = torch.sum(a[:, :, i] * b[:, :, j])\n",
    "out = (1 / (h * w)) * out\n",
    "# CUDA kernels are queued asynchronously; wait for them so the timing covers the real work.\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.17 ms ± 551 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Same channel-pair statistic as the nested-loop version, written as one einsum\n",
    "# contraction over the two leading (spatial) axes m, n.\n",
    "a = torch.randn(112, 56, 32).cuda()\n",
    "b = torch.randn(112, 56, 64).cuda()\n",
    "h, w = a.shape[:2]\n",
    "out2 = (1 / (h * w)) * torch.einsum(\"mnr, mnd->rd\", a, b)\n",
    "# CUDA kernels are queued asynchronously; wait for them so the timing covers the real work.\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.0728e-02,  8.2371e-03, -1.5353e-02,  ..., -1.5806e-02,\n",
       "          1.2172e-03, -1.6790e-02],\n",
       "        [-1.4771e-02, -1.0435e-02, -5.8164e-03,  ...,  1.4146e-02,\n",
       "          1.4725e-02,  7.8172e-03],\n",
       "        [-1.4103e-02,  1.1401e-02,  7.5380e-04,  ...,  4.3215e-03,\n",
       "          3.1517e-03, -1.0627e-02],\n",
       "        ...,\n",
       "        [-3.2658e-03, -1.5833e-02,  1.6929e-02,  ..., -2.2531e-03,\n",
       "         -9.6723e-03,  5.1153e-03],\n",
       "        [-3.8121e-02, -1.0297e-02,  1.2941e-02,  ..., -8.0578e-05,\n",
       "          8.0383e-03, -1.9793e-02],\n",
       "        [-2.6047e-02,  8.7224e-05,  1.5107e-02,  ...,  1.9547e-02,\n",
       "          3.5168e-03,  1.0607e-02]], device='cuda:0')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.0728e-02,  8.2371e-03, -1.5353e-02,  ..., -1.5806e-02,\n",
       "          1.2172e-03, -1.6790e-02],\n",
       "        [-1.4771e-02, -1.0435e-02, -5.8164e-03,  ...,  1.4146e-02,\n",
       "          1.4725e-02,  7.8172e-03],\n",
       "        [-1.4103e-02,  1.1401e-02,  7.5380e-04,  ...,  4.3215e-03,\n",
       "          3.1517e-03, -1.0627e-02],\n",
       "        ...,\n",
       "        [-3.2658e-03, -1.5833e-02,  1.6929e-02,  ..., -2.2531e-03,\n",
       "         -9.6723e-03,  5.1153e-03],\n",
       "        [-3.8121e-02, -1.0297e-02,  1.2941e-02,  ..., -8.0572e-05,\n",
       "          8.0383e-03, -1.9793e-02],\n",
       "        [-2.6047e-02,  8.7224e-05,  1.5107e-02,  ...,  1.9547e-02,\n",
       "          3.5168e-03,  1.0607e-02]], device='cuda:0')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out2 * (1 / (h * w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([6272, 32])\n",
      "torch.Size([6272, 64])\n"
     ]
    }
   ],
   "source": [
    "print(a.reshape(-1, 32).shape)\n",
    "print(b.reshape(-1, 64).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.71 ms ± 369 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Same statistic as the loop / einsum versions: flatten the spatial axes and use a\n",
    "# single (32 x HW) @ (HW x 64) matrix product.\n",
    "a = torch.randn(112, 56, 32).cuda()\n",
    "b = torch.randn(112, 56, 64).cuda()\n",
    "h, w = a.shape[:2]\n",
    "# Apply the 1 / (h * w) factor here too so all three variants compute the same quantity.\n",
    "out3 = (1 / (h * w)) * (a.reshape(-1, 32).transpose(1, 0) @ b.reshape(-1, 64))\n",
    "# CUDA kernels are queued asynchronously; wait for them so the timing covers the real work.\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fsp(a, b):\n",
    "    \"\"\"Spatially averaged channel-by-channel inner products of two feature maps.\n",
    "\n",
    "    `a` is max-pooled to the spatial size of `b` (its last two dims), then the two\n",
    "    maps are contracted over those last two dims and divided by their element count.\n",
    "\n",
    "    Args:\n",
    "        a: tensor shaped (..., r, H_a, W_a), in a layout `F.adaptive_max_pool2d` accepts.\n",
    "        b: tensor shaped (..., d, H, W); (H, W) is the pooling target for `a`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor shaped (..., r, d); entry [..., i, j] is the mean over (H, W) of\n",
    "        pooled_a[..., i, :, :] * b[..., j, :, :].\n",
    "    \"\"\"\n",
    "    a = F.adaptive_max_pool2d(a, output_size=b.shape[-2:])\n",
    "    # Normalise by the number of spatial positions the einsum sums over, i.e. the\n",
    "    # last two dims (indexing from the front would pick up the channel dim on batched input).\n",
    "    h, w = a.shape[-2:]\n",
    "    return (1 / (h * w)) * torch.einsum(\"...rmn, ...dmn->...rd\", a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "85.4 µs ± 24.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# Uses the batched, channels-first `a` / `b` tensors created in the setup cell near the top.\n",
    "out = fsp(a, b)\n",
    "# CUDA kernels are queued asynchronously; without this the cell mostly measures launch overhead.\n",
    "torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 32, 112, 56])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ak_fastai)",
   "language": "python",
   "name": "ak_fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
